package neuroph;

import cn.hutool.core.date.DateUtil;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.ybg.share.ShareApplication;
import com.ybg.share.core.dbapi.entity.ShareStockIndex;
import com.ybg.share.core.dbapi.service.ShareCompositeIndexService;
import com.ybg.share.framework.enums.MarketEnum;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.IterativeLearning;
import org.neuroph.nnet.Perceptron;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringRunner;

import java.math.BigDecimal;
import java.util.List;

/**
 * Experimental test that trains a Neuroph {@link Perceptron} on Shanghai ({@code MarketEnum.sh})
 * index records dated after 2019-01-01 and prints a rise / fall-or-flat prediction for one
 * hand-picked input.
 *
 * <p>Each training row has two inputs (trade volume and trade money, both rescaled by the
 * constants below) and one binary output: {@code 1} when the record's change amount is positive,
 * otherwise {@code 0}.
 */
@SpringBootTest(classes = ShareApplication.class)
@RunWith(SpringRunner.class)
public class StockIndex {

    /** Divisor applied to {@code tradeNum} before it is fed to the network. */
    private static final double TRADE_NUM_SCALE = 1_000_000d;

    /** Divisor applied to {@code tradeMoney} before it is fed to the network. */
    private static final double TRADE_MONEY_SCALE = 1_000_000_000d;

    /**
     * Upper bound on training iterations. Neuroph's iterative learning is not iteration-limited
     * unless {@code setMaxIterations} is called, and perceptron learning only stops on its own once
     * the error falls below the rule's max error, which need not happen for data that is not
     * linearly separable. Without this cap {@code learn} may never return.
     */
    private static final int MAX_LEARNING_ITERATIONS = 10_000;

    /** Network outputs at or above this value are read as "rise". */
    private static final double RISE_THRESHOLD = 0.5d;

    @Autowired
    ShareCompositeIndexService shareCompositeIndexService;

    /**
     * Loads the index history, trains the perceptron, saves it to {@code /test/shareText1.nnet},
     * and prints the prediction for a fixed sample input.
     *
     * @throws IllegalStateException if no record yields a usable training row
     */
    @Test
    public void indexGenNeuroph() {
        QueryWrapper<ShareStockIndex> wrapper = new QueryWrapper<>();
        wrapper.eq(ShareStockIndex.MARKET, MarketEnum.sh.getCode());
        wrapper.gt(ShareStockIndex.DATE, "2019-01-01");
        wrapper.orderByAsc(ShareStockIndex.DATE);
        List<ShareStockIndex> list = shareCompositeIndexService.list(wrapper);

        DataSet trainingSet = new DataSet(2, 1);
        for (ShareStockIndex shareStockIndex : list) {
            DataSetRow row = toTrainingRow(shareStockIndex);
            if (row != null) {
                trainingSet.add(row);
            }
        }
        System.out.println("循环结束" + DateUtil.now());
        // Training on an empty set is meaningless; fail with a clear message instead.
        if (trainingSet.size() == 0) {
            throw new IllegalStateException(
                    "No usable training rows for market " + MarketEnum.sh.getCode());
        }

        NeuralNetwork neuralNetwork = new Perceptron(2, 1);
        Object learningRule = neuralNetwork.getLearningRule();
        if (learningRule instanceof IterativeLearning) {
            ((IterativeLearning) learningRule).setMaxIterations(MAX_LEARNING_ITERATIONS);
        }
        System.out.println("正在学习" + DateUtil.now());
        neuralNetwork.learn(trainingSet);
        System.out.println("学习结束" + DateUtil.now());
        neuralNetwork.save("/test/shareText1.nnet");

        // NOTE(review): presumably these are already in the scaled units used for training
        // (tradeNum / TRADE_NUM_SCALE, tradeMoney / TRADE_MONEY_SCALE) — confirm.
        neuralNetwork.setInput(180, 166.0000);
        neuralNetwork.calculate();
        System.out.println("准备输出结果" + DateUtil.now());
        double[] networkOutput = neuralNetwork.getOutput();
        // Threshold instead of an exact == 0 comparison, so the read-out does not rely on the
        // output being exactly 0 or 1.
        System.out.println("结果=" + (networkOutput[0] < RISE_THRESHOLD ? "跌或持平" : "涨"));
    }

    /**
     * Converts one index record into a training row.
     *
     * @param index the record to convert
     * @return a row with inputs {@code [tradeNum / TRADE_NUM_SCALE, tradeMoney / TRADE_MONEY_SCALE]}
     *     and output {@code [1]} if the change amount is positive, else {@code [0]}; or
     *     {@code null} if any of the three fields is missing
     */
    private static DataSetRow toTrainingRow(ShareStockIndex index) {
        if (index.getChangeAmount() == null
                || index.getTradeNum() == null
                || index.getTradeMoney() == null) {
            return null;
        }
        double[] input = {
            index.getTradeNum().doubleValue() / TRADE_NUM_SCALE,
            index.getTradeMoney().doubleValue() / TRADE_MONEY_SCALE
        };
        double[] output = {index.getChangeAmount().compareTo(BigDecimal.ZERO) > 0 ? 1d : 0d};
        return new DataSetRow(input, output);
    }
}
